package com.kf.config.filter;


import com.kf.config.RedisUtil;
import com.kf.config.SecurityUser;
import com.kf.pojo.User;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

/**
 * CSRF protection filter.
 *
 * <p>Requests whose HTTP method is in {@code SAFE_METHODS}, and requests whose URI ends with an
 * entry of {@code IGNORE_CSRF_LIST}, pass through unchecked. Every other request must:
 * <ol>
 *   <li>carry a {@code Referer} header whose host equals, or is a subdomain of, one of the
 *       configured domains (skipped when no domains are configured), and</li>
 *   <li>send a {@code csrfToken} header equal to the token stored in Redis under the username of
 *       the session attribute {@code loginUser}.</li>
 * </ol>
 * Failing either check results in HTTP 403.
 *
 * @author Honey
 * @date 2022/12/23
 */
public class CsrfFilter extends OncePerRequestFilter {

    /** Request header that carries the client's CSRF token. */
    private static final String CSRF_HEADER = "csrfToken";

    /** Session attribute that holds the logged-in {@link User}. */
    private static final String LOGIN_USER_ATTRIBUTE = "loginUser";

    /** URIs exempt from CSRF checks; matched by suffix (see {@link #verifyIgnoreApi(String)}). */
    private static final List<String> IGNORE_CSRF_LIST = Collections.unmodifiableList(
            Arrays.asList("/login/user", "/login/perms", "/login/phone"));

    /** HTTP methods that are never CSRF-checked. */
    private static final List<String> SAFE_METHODS = Collections.unmodifiableList(
            Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));

    /** Allowed Referer hosts: lower-cased, without a leading dot. Empty disables the Referer check. */
    private final List<String> domains;

    /**
     * Creates the filter.
     *
     * @param domains hosts allowed in the {@code Referer} header. A Referer host matches a domain
     *                when it equals it or is a subdomain of it; comparison is case-insensitive.
     *                {@code null}, blank entries, and an empty collection are ignored. When no
     *                usable domain remains, the Referer check is skipped.
     */
    public CsrfFilter(Collection<String> domains) {
        List<String> normalized = new ArrayList<>();
        if (domains != null) {
            for (String domain : domains) {
                if (domain == null) {
                    continue;
                }
                String d = domain.trim().toLowerCase(Locale.ROOT);
                // Accept ".example.com" style entries; subdomain matching is handled in verifyDomains.
                while (d.startsWith(".")) {
                    d = d.substring(1);
                }
                if (!d.isEmpty()) {
                    normalized.add(d);
                }
            }
        }
        this.domains = Collections.unmodifiableList(normalized);
    }

    /**
     * Looks up a Spring-managed bean from the web application context.
     *
     * <p>Filters are not created by Spring here, so {@code @Autowired} is unavailable. This filter
     * uses it to obtain {@link RedisUtil}.
     *
     * @param clazz   bean type to fetch
     * @param request current request, used to reach the servlet context
     * @param <T>     bean type
     * @return the bean instance
     */
    public <T> T getBean(Class<T> clazz, HttpServletRequest request) {
        WebApplicationContext applicationContext =
                WebApplicationContextUtils.getRequiredWebApplicationContext(request.getServletContext());
        return applicationContext.getBean(clazz);
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
                                    FilterChain filterChain) throws ServletException, IOException {
        if (SAFE_METHODS.contains(request.getMethod())) {
            filterChain.doFilter(request, response);
            return;
        }

        String uri = request.getRequestURI();
        if (uri == null || verifyIgnoreApi(uri)) {
            filterChain.doFilter(request, response);
            return;
        }

        if (!domains.isEmpty() && !verifyDomains(request)) {
            response.sendError(HttpServletResponse.SC_FORBIDDEN, "CSRF Protection: Referer Illegal");
            return;
        }

        String csrfToken = request.getHeader(CSRF_HEADER);
        if (!verifyToken(request, csrfToken)) {
            response.sendError(HttpServletResponse.SC_FORBIDDEN,
                    csrfToken == null ? "CSRF Token Missing" : "CSRF Token Invalid");
            return;
        }

        filterChain.doFilter(request, response);
    }

    /**
     * Checks that the {@code Referer} host is one of the configured domains or a subdomain of one.
     *
     * @return {@code false} when the header is absent, unparsable, has no host, or does not match
     */
    private boolean verifyDomains(HttpServletRequest request) {
        String host = extractHost(request.getHeader("Referer"));
        if (host == null) {
            return false;
        }
        for (String domain : domains) {
            // Exact host or a dot-separated subdomain; a bare suffix match would accept "evil-example.com".
            if (host.equals(domain) || host.endsWith("." + domain)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Extracts the lower-cased host from an absolute Referer URL.
     *
     * @return the host, or {@code null} if the value is missing, malformed, or has no server-based
     *         host (for example, a relative reference)
     */
    private static String extractHost(String referer) {
        if (referer == null || referer.trim().isEmpty()) {
            return null;
        }
        try {
            // URI handles scheme, userinfo, port, query and fragment; hand-rolled parsing missed "?".
            String host = new URI(referer.trim()).getHost();
            return host == null ? null : host.toLowerCase(Locale.ROOT);
        } catch (URISyntaxException e) {
            return null;
        }
    }

    /**
     * Compares the client token with the one stored in Redis for the session's logged-in user.
     *
     * @param csrfToken value of the {@code csrfToken} header; may be {@code null}
     * @return {@code true} only if there is a session with a {@link User} under {@code loginUser},
     *         Redis holds a String token for that user's username, and it equals {@code csrfToken}
     */
    private boolean verifyToken(HttpServletRequest request, String csrfToken) {
        if (csrfToken == null || csrfToken.isEmpty()) {
            return false;
        }

        // getSession(false): a CSRF check must not create a session as a side effect.
        HttpSession session = request.getSession(false);
        if (session == null) {
            return false;
        }

        Object attribute = session.getAttribute(LOGIN_USER_ATTRIBUTE);
        if (!(attribute instanceof User)) {
            return false;
        }
        Object username = ((User) attribute).getUsername();
        if (username == null) {
            return false;
        }

        RedisUtil redisUtils = getBean(RedisUtil.class, request);
        Object stored = redisUtils.get(String.valueOf(username));
        if (!(stored instanceof String)) {
            return false;
        }
        return constantTimeEquals((String) stored, csrfToken);
    }

    /** Compares two tokens without leaking, via timing, how many leading bytes matched. */
    private static boolean constantTimeEquals(String expected, String actual) {
        return MessageDigest.isEqual(expected.getBytes(StandardCharsets.UTF_8),
                actual.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Returns whether {@code uri} ends with one of the exempt paths in {@code IGNORE_CSRF_LIST}.
     *
     * <p>NOTE(review): suffix matching tolerates a servlet context path, but it also exempts any
     * other endpoint whose path ends the same way (e.g. {@code /x/login/user}). Confirm that no
     * such endpoint exists.
     */
    private boolean verifyIgnoreApi(String uri) {
        for (String ignoreApi : IGNORE_CSRF_LIST) {
            if (uri.endsWith(ignoreApi)) {
                return true;
            }
        }
        return false;
    }
}